{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "class MNISTLoader:\n",
    "    \"\"\"Load MNIST via tf.keras and serve random training mini-batches.\n",
    "\n",
    "    After construction the images are float32 arrays scaled to [0, 1] with a\n",
    "    trailing channel axis, and the labels are int32 arrays. The sample counts\n",
    "    of both splits are kept in `num_train_data` and `num_test_data`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        mnist = tf.keras.datasets.mnist\n",
    "        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()\n",
    "        # MNIST images are uint8 (0-255) by default. Scale them to floats in\n",
    "        # [0, 1] and append a trailing axis that serves as the colour channel.\n",
    "        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]\n",
    "        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]\n",
    "        self.train_label = self.train_label.astype(np.int32)    # [60000]\n",
    "        self.test_label = self.test_label.astype(np.int32)      # [10000]\n",
    "        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]\n",
    "\n",
    "    def get_batch(self, batch_size):\n",
    "        \"\"\"Return `batch_size` randomly chosen training samples.\n",
    "\n",
    "        Indices are drawn independently with `np.random.randint`, so the same\n",
    "        sample may appear more than once in a batch.\n",
    "\n",
    "        Args:\n",
    "            batch_size: number of samples to draw.\n",
    "\n",
    "        Returns:\n",
    "            A tuple `(images, labels)` whose first axis has length `batch_size`.\n",
    "        \"\"\"\n",
    "        index = np.random.randint(0, self.num_train_data, batch_size)\n",
    "        return self.train_data[index], self.train_label[index]\n",
    "    \n",
    "    \n",
    "class MLP(tf.keras.Model):\n",
    "    \"\"\"Multilayer perceptron: flatten -> 100-unit ReLU layer -> 10-unit layer -> softmax.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Flatten collapses every axis except the first (batch) one.\n",
    "        self.flatten = tf.keras.layers.Flatten()\n",
    "        self.dense1 = tf.keras.layers.Dense(100, activation=tf.nn.relu)\n",
    "        self.dense2 = tf.keras.layers.Dense(10)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Map inputs of shape [batch_size, 28, 28, 1] to softmax outputs of shape [batch_size, 10].\"\"\"\n",
    "        flat = self.flatten(inputs)       # [batch_size, 784]\n",
    "        hidden = self.dense1(flat)        # [batch_size, 100]\n",
    "        logits = self.dense2(hidden)      # [batch_size, 10]\n",
    "        return tf.nn.softmax(logits)\n",
    "    \n",
    "    \n",
    "# --- Hyperparameters ---\n",
    "num_epochs = 5\n",
    "batch_size = 50\n",
    "learning_rate = 0.001\n",
    "seed = 42           # fixed seed so that Restart & Run All reproduces the same run\n",
    "log_every = 100     # print the training loss once every `log_every` batches\n",
    "\n",
    "# Seed NumPy (mini-batch sampling) and TensorFlow (weight initialisation).\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "\n",
    "# Instantiate the model, the data loader and an Adam optimizer.\n",
    "model = MLP()\n",
    "data_loader = MNISTLoader()\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "num_batches = data_loader.num_train_data // batch_size * num_epochs\n",
    "\n",
    "# --- Training ---\n",
    "for batch_index in range(num_batches):\n",
    "    X, y = data_loader.get_batch(batch_size)\n",
    "    with tf.GradientTape() as tape:\n",
    "        # Only the forward pass and the loss have to be recorded by the tape.\n",
    "        y_pred = model(X)\n",
    "        loss = tf.reduce_mean(\n",
    "            tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)\n",
    "        )\n",
    "    # Log outside the tape, and only periodically, so the cell output stays readable.\n",
    "    if batch_index % log_every == 0 or batch_index == num_batches - 1:\n",
    "        print(\"batch %d: loss %f\" % (batch_index, loss.numpy()))\n",
    "    # Differentiate with respect to the trainable variables only: a non-trainable\n",
    "    # variable would produce a None gradient that the optimizer cannot apply.\n",
    "    grads = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))\n",
    "\n",
    "# --- Evaluation ---\n",
    "sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "# Walk the test set in strides of batch_size. Slicing clamps the final window,\n",
    "# so a trailing partial batch is scored instead of being silently dropped.\n",
    "for start_index in range(0, data_loader.num_test_data, batch_size):\n",
    "    end_index = start_index + batch_size\n",
    "    # Call the model directly, as the training loop does, rather than invoking\n",
    "    # model.predict once per small batch.\n",
    "    y_pred = model(data_loader.test_data[start_index:end_index])\n",
    "    sparse_categorical_accuracy.update_state(\n",
    "        y_true=data_loader.test_label[start_index:end_index], y_pred=y_pred\n",
    "    )\n",
    "print(\"test accuracy: %f\" % float(sparse_categorical_accuracy.result()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
